"""
Probe the K/L sign flip: why does the demographic coefficient on delta_log_kl
flip sign across subsamples, and why does the Z_1 x rule_of_law interaction reverse sign?
"""

import sys
from pathlib import Path
import numpy as np
import pandas as pd

PROJECT_DIR = Path("/mnt/c/demographics_capital_flows/capital_deepening")
ROOT_DIR = PROJECT_DIR.parent
sys.path.insert(0, str(ROOT_DIR / "multilateral" / "src"))
from model import PanelGLS

# Country-year panel (one row per iso3 x year) from the processed-data folder.
df = pd.read_csv(PROJECT_DIR / "data" / "processed" / "deepening_panel.csv")

# ---------- setup ----------
# ISO3 codes of the 38 OECD members; drives the OECD / non-OECD split below.
OECD_38 = """
    AUS AUT BEL CAN CHL COL CRI CZE DNK EST FIN FRA
    DEU GRC HUN ISL IRL ISR ITA JPN KOR LVA LTU LUX
    MEX NLD NZL NOR POL PRT SVK SVN ESP SWE CHE TUR
    GBR USA
""".split()

# 0/1 membership flag used to subset the panel in every section.
df["oecd"] = df["iso3"].isin(OECD_38).astype(int)

# Regressor sets: the Z_1..Z_3 variables first, then the controls.
CONTROLS = ["fiscal_bal_gdp", "nfa_gdp_lag", "log_rel_opw", "kaopen"]
Z_VARS = [f"Z_{k}" for k in (1, 2, 3)]
ALL_RHS = [*Z_VARS, *CONTROLS]

# Quick inventory of what the panel contains.
banner = "=" * 80
print(banner, "COLUMNS IN PANEL (sorted):", banner, sep="\n")
print(", ".join(sorted(df.columns)))
year_lo, year_hi = df["year"].min(), df["year"].max()
print(f"\nPanel shape: {df.shape}")
print(f"Year range: {year_lo} - {year_hi}")
print(f"Countries: {df['iso3'].nunique()}\n")


def run_reg(data, dep_var, rhs_vars, label="", min_obs=20):
    """Fit ``PanelGLS`` of ``dep_var`` on ``rhs_vars`` over complete cases of ``data``.

    Parameters
    ----------
    data : pandas.DataFrame
        Panel holding ``dep_var``, every column in ``rhs_vars``, and the
        identifier columns ``iso3`` (passed to ``fit`` as the entity) and
        ``year`` (passed as the time index).
    dep_var : str
        Name of the dependent-variable column.
    rhs_vars : sequence of str
        Regressor columns; coefficients are reported in this order. Any
        sequence (list, tuple, ...) is accepted.
    label : str, optional
        Tag printed in diagnostics and stored in the result.
    min_obs : int, optional
        Minimum number of complete-case rows required to run the regression.
        The default of 20 is the original threshold.

    Returns
    -------
    dict or None
        ``None`` (after printing a notice) when fewer than ``min_obs`` rows
        survive the missing-value filter. Otherwise a dict with keys
        ``label``, ``dep_var``, ``n_obs``, ``n_countries``, ``r_squared``,
        ``vars``, ``beta``, ``se`` and ``pvalues``. ``beta``, ``se`` and
        ``pvalues`` are taken from the fitted ``PanelGLS``.
        NOTE(review): downstream code assumes they are positionally aligned
        with ``vars``; confirm in ``model.PanelGLS``.
    """
    # Copy so the stored "vars" never aliases a caller's module-level list,
    # and so tuples or other sequences work in the column list below.
    rhs_vars = list(rhs_vars)
    cols = [dep_var, *rhs_vars, "iso3", "year"]

    # Complete-case sample: drop rows missing the outcome, any regressor, or an id.
    sub = data[cols].dropna()
    if len(sub) < min_obs:
        print(f"  [{label}] Insufficient obs: {len(sub)}")
        return None

    gls = PanelGLS()
    gls.fit(
        sub[dep_var].values,
        sub[rhs_vars].values,
        sub["iso3"].values,
        sub["year"].values,
    )
    return {
        "label": label, "dep_var": dep_var,
        "n_obs": gls.n_obs, "n_countries": gls.n_countries,
        "r_squared": gls.r_squared, "vars": rhs_vars,
        "beta": gls.beta, "se": gls.se, "pvalues": gls.pvalues,
    }


def print_results(res, show_controls=False):
    """Print one regression summary to stdout; do nothing when ``res`` is None.

    Parameters
    ----------
    res : dict or None
        Result dict as returned by ``run_reg`` (keys used here: label, n_obs,
        n_countries, r_squared, vars, beta, se, pvalues).
    show_controls : bool, default False
        When False, coefficients on the variables in CONTROLS are omitted.
    """
    if res is None:
        return

    # Significance marks, checked from the strictest threshold outward.
    cutoffs = ((0.01, "***"), (0.05, "**"), (0.10, "*"))

    print(f"  {res['label']}")
    print(f"  N={res['n_obs']}, Countries={res['n_countries']}, R2={res['r_squared']:.4f}")
    for idx, name in enumerate(res["vars"]):
        if show_controls or name not in CONTROLS:
            pval = res["pvalues"][idx]
            stars = next((mark for cut, mark in cutoffs if pval < cut), "")
            print(f"    {name:30s}  b={res['beta'][idx]:10.4f}  SE={res['se'][idx]:8.4f}  p={pval:.4f}{stars}")
    print()


def section(title):
    """Print ``title`` (indented two spaces) between two 80-character '=' rules."""
    rule = "=" * 80
    print("\n".join((rule, f"  {title}", rule)))


# Every successful regression is collected here for the Markdown report.
all_results = []

# Sections 1-3: the same K/L specification on the pooled panel, then on the
# OECD and non-OECD halves. Controls are shown for these baseline runs.
for heading, label, subset in (
    ("1. FULL SAMPLE: Z -> delta_log_kl", "Full sample", df),
    ("2. OECD ONLY: Z -> delta_log_kl", "OECD only", df[df["oecd"] == 1]),
    ("3. NON-OECD: Z -> delta_log_kl", "Non-OECD only", df[df["oecd"] == 0]),
):
    section(heading)
    r = run_reg(subset, "delta_log_kl", ALL_RHS, label)
    print_results(r, show_controls=True)
    if r:
        all_results.append(r)

# 4. Income terciles. Cut points are pooled over all country-years that have
# both gdp_pc_ppp and delta_log_kl, and masks are applied row by row, so a
# country can fall in different bands in different years.
section("4. INCOME TERCILES: Z -> delta_log_kl")
df_valid = df.dropna(subset=["gdp_pc_ppp", "delta_log_kl"])
tercile_cuts = df_valid["gdp_pc_ppp"].quantile([1/3, 2/3])
t1 = tercile_cuts.iloc[0]
t2 = tercile_cuts.iloc[1]
print(f"  GDP/capita tercile boundaries: ${t1:,.0f} | ${t2:,.0f}\n")

gdp = df["gdp_pc_ppp"]
income_bands = {
    "Low income (bottom tercile)": gdp <= t1,
    "Middle income": (gdp > t1) & (gdp <= t2),
    "High income (top tercile)": gdp > t2,
}
for label, mask in income_bands.items():
    r = run_reg(df[mask], "delta_log_kl", ALL_RHS, label)
    print_results(r)
    if r:
        all_results.append(r)

# 5. Rule-of-law interaction. NOTE(review): Z_1 and rule_of_law enter the
# product uncentered. In a linear specification, the Z_1 main effect is
# therefore the slope at rule_of_law == 0. Keep that in mind when comparing
# the OECD and non-OECD rows.
section("5. RULE OF LAW INTERACTION: Z1 x rule_of_law -> delta_log_kl")
df["Z1_x_rol"] = df["Z_1"] * df["rule_of_law"]
ROL_RHS = [*Z_VARS, "Z1_x_rol", "rule_of_law", *CONTROLS]

is_oecd = df["oecd"] == 1
rol_samples = (
    ("Full sample (RoL)", df),
    ("OECD (RoL)", df[is_oecd]),
    ("Non-OECD (RoL)", df[~is_oecd]),
)
for label, subset in rol_samples:
    r = run_reg(subset, "delta_log_kl", ROL_RHS, label)
    print_results(r)
    if r:
        all_results.append(r)

# 6. fdi_liab_gdp (presumably FDI liabilities / GDP) as a K/L driver,
# alone and alongside the Z variables.
section("6. FDI -> delta_log_kl")
fdi_candidates = list(filter(lambda col: "fdi" in col.lower(), df.columns))
print(f"  FDI variables available: {fdi_candidates}")
print(f"  fdi_liab_gdp non-null: {df['fdi_liab_gdp'].notna().sum()}\n")

FDI_RHS_liab = ["fdi_liab_gdp", *CONTROLS]
FDI_RHS_full = ["fdi_liab_gdp", *Z_VARS, *CONTROLS]

oecd_rows = df[df["oecd"] == 1]
non_oecd_rows = df[df["oecd"] == 0]
fdi_specs = [
    ("Full: FDI only", df, FDI_RHS_liab),
    ("Full: FDI + Z", df, FDI_RHS_full),
    ("OECD: FDI + Z", oecd_rows, FDI_RHS_full),
    ("Non-OECD: FDI + Z", non_oecd_rows, FDI_RHS_full),
]
for label, subset, rhs in fdi_specs:
    r = run_reg(subset, "delta_log_kl", rhs, label)
    print_results(r)
    if r:
        all_results.append(r)

# 7. Same regressor set with investment/GDP (I/Y) as the outcome, for
# comparison with the delta_log_kl results above.
section("7. COMPARISON: Z -> gross_fixed_investment_gdp (I/Y)")
iy_var = "gross_fixed_investment_gdp"
print(f"  {iy_var} non-null: {df[iy_var].notna().sum()}\n")

iy_samples = (
    ("Full sample (I/Y)", df),
    ("OECD (I/Y)", df[df["oecd"] == 1]),
    ("Non-OECD (I/Y)", df[df["oecd"] == 0]),
)
for label, subset in iy_samples:
    r = run_reg(subset, iy_var, ALL_RHS, label)
    print_results(r)
    if r:
        all_results.append(r)

# Reuses the tercile cut points t1 / t2 computed on the K/L sample in section 4.
print("  --- Income terciles on I/Y ---")
gdp_pc = df["gdp_pc_ppp"]
iy_bands = (
    ("Low income (I/Y)", gdp_pc <= t1),
    ("Middle income (I/Y)", (gdp_pc > t1) & (gdp_pc <= t2)),
    ("High income (I/Y)", gdp_pc > t2),
)
for label, mask in iy_bands:
    r = run_reg(df[mask], iy_var, ALL_RHS, label)
    print_results(r)
    if r:
        all_results.append(r)

# Diagnostics: size and outcome moments of the complete-case K/L estimation
# sample in each split.
section("DIAGNOSTIC: Sample composition")
required = ["delta_log_kl"] + ALL_RHS
for label, subset in (
    ("Full", df),
    ("OECD", df[df["oecd"] == 1]),
    ("Non-OECD", df[df["oecd"] == 0]),
):
    sub = subset.dropna(subset=required)
    kl = sub["delta_log_kl"]
    print(
        f"  {label:12s}: N={len(sub):5d}, countries={sub['iso3'].nunique():3d}, "
        f"mean(delta_log_kl)={kl.mean():.5f}, "
        f"std={kl.std():.5f}, "
        f"mean(Z_1)={sub['Z_1'].mean():.4f}"
    )
print()

# Markdown report: choose the destination, then start the list of lines that
# later sections append to and finally join with newlines.
section("WRITING MARKDOWN")
out_path = PROJECT_DIR.joinpath("output", "tables", "kl_sign_flip_probe.md")
out_path.parent.mkdir(parents=True, exist_ok=True)

lines = [
    "# K/L Sign Flip Probe\n",
    "Investigation of why the demographic coefficient on delta-log(K/L) flips sign across subsamples.\n",
]

def fmt_table(results_list, highlight_vars=None):
    """Render regression results as a Markdown table.

    There is one row per result dict (as produced by ``run_reg``). Each row
    starts with the label, N, number of countries and R^2 (3 decimals). The
    table then has one column per variable that appears in both
    ``highlight_vars`` and at least one result's ``vars``, ordered by first
    appearance. Cells read ``beta<stars> (p)``, with stars for p < 0.01 /
    0.05 / 0.10, or ``--`` when that specification omits the variable.

    ``highlight_vars`` defaults to Z_VARS plus Z1_x_rol, rule_of_law and
    fdi_liab_gdp. Returns the table as one newline-joined string with no
    trailing newline.
    """
    if highlight_vars is None:
        wanted = Z_VARS + ["Z1_x_rol", "rule_of_law", "fdi_liab_gdp"]
    else:
        wanted = highlight_vars

    # dict.fromkeys keeps first-seen order while dropping repeats.
    columns = list(dict.fromkeys(
        name for res in results_list for name in res["vars"] if name in wanted
    ))

    def cell(res, name):
        # "--" marks a variable this specification did not include.
        if name not in res["vars"]:
            return " -- |"
        k = res["vars"].index(name)
        coef, pval = res["beta"][k], res["pvalues"][k]
        if pval < 0.01:
            stars = "***"
        elif pval < 0.05:
            stars = "**"
        elif pval < 0.10:
            stars = "*"
        else:
            stars = ""
        return f" {coef:.3f}{stars} ({pval:.3f}) |"

    table = [
        "| Specification | N | Countries | R2 |" + "".join(f" {name} |" for name in columns),
        "|---|---|---|---|" + "---|" * len(columns),
    ]
    for res in results_list:
        lead = f"| {res['label']} | {res['n_obs']} | {res['n_countries']} | {res['r_squared']:.3f} |"
        table.append(lead + "".join(cell(res, name) for name in columns))
    return "\n".join(table)

# ---------- report body ----------
# Sections pick results out of all_results by the labels assigned when the
# regressions were run, so a label change there must be mirrored here.


def _coef_sign(res, var):
    """Return the sign of ``var``'s coefficient in ``res`` plus significance stars.

    '+', '-' or '0' followed by '***' / '**' / '*' for p < 0.01 / 0.05 / 0.10.
    Returns '--' when the specification does not include ``var`` and 'n/a'
    for a non-finite estimate.
    """
    if var not in res["vars"]:
        return "--"
    idx = res["vars"].index(var)
    b, p = res["beta"][idx], res["pvalues"][idx]
    if not np.isfinite(b):
        return "n/a"
    sign = "+" if b > 0 else "-" if b < 0 else "0"
    stars = "***" if p < 0.01 else "**" if p < 0.05 else "*" if p < 0.10 else ""
    return sign + stars


lines.append("## 1. K/L Regressions by Sample\n")
# Derived from CONTROLS so the report cannot drift from the actual specification.
lines.append(f"Dependent variable: delta-log(K/L). Controls: {', '.join(CONTROLS)}.\n")
sec1 = [r for r in all_results if r["label"] in ["Full sample", "OECD only", "Non-OECD only"]]
lines.append(fmt_table(sec1))
lines.append("")

lines.append("## 2. Income Terciles\n")
lines.append(
    f"Tercile cut-points on gdp_pc_ppp (pooled country-years with non-missing "
    f"gdp_pc_ppp and delta_log_kl): ${t1:,.0f} | ${t2:,.0f}.\n"
)
sec4 = [r for r in all_results if "tercile" in r["label"] or r["label"] == "Middle income"]
lines.append(fmt_table(sec4))
lines.append("")

lines.append("## 3. Rule of Law Interaction\n")
sec5 = [r for r in all_results if "RoL" in r["label"]]
lines.append(fmt_table(sec5))
lines.append("")

lines.append("## 4. FDI -> K/L\n")
sec6 = [r for r in all_results if "FDI" in r["label"]]
lines.append(fmt_table(sec6))
lines.append("")

lines.append("## 5. Comparison: Z -> I/Y (gross_fixed_investment_gdp)\n")
lines.append("Same specifications but on investment/GDP instead of delta-log(K/L).\n")
sec7 = [r for r in all_results if "I/Y" in r["label"]]
lines.append(fmt_table(sec7))
lines.append("")

# Section 6 used to promise a summary that was never written; build an
# actual sign summary of the Z coefficients across every K/L specification.
lines.append("## 6. Interpretation\n")
lines.append(
    "Sign of each demographic coefficient on delta-log(K/L) by specification "
    "(+/- = sign of the point estimate; stars: *** p<0.01, ** p<0.05, * p<0.10; "
    "-- = variable not in the specification). A column whose signs differ "
    "across rows is a sign flip.\n"
)
kl_results = [res for res in all_results if res["dep_var"] == "delta_log_kl"]
sign_table = [
    "| Specification | " + " | ".join(Z_VARS) + " |",
    "|---|" + "---|" * len(Z_VARS),
]
for res in kl_results:
    sign_table.append(f"| {res['label']} | " + " | ".join(_coef_sign(res, v) for v in Z_VARS) + " |")
lines.append("\n".join(sign_table))
lines.append("")

for v in Z_VARS:
    n_pos = n_neg = n_total = 0
    for res in kl_results:
        if v not in res["vars"]:
            continue
        n_total += 1
        b = res["beta"][res["vars"].index(v)]
        if b > 0:
            n_pos += 1
        elif b < 0:
            n_neg += 1
    lines.append(f"- {v}: positive in {n_pos}, negative in {n_neg} of {n_total} delta-log(K/L) specifications.")
lines.append("")
lines.append("*(Auto-generated -- see this script's console output for standard errors and control coefficients.)*\n")

# Explicit encoding so the report does not depend on the platform default.
out_path.write_text("\n".join(lines), encoding="utf-8")

print(f"  Saved to: {out_path}")
print("  Done.")
